{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.database import get_protocol, FileFinder\n",
    "protocol = get_protocol('Debug.SpeakerDiarization.Debug', \n",
    "                        preprocessors={\"audio\": FileFinder()})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and apply a voice activity detection model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.audio.tasks import VoiceActivityDetection\n",
    "from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vad = VoiceActivityDetection(protocol, duration=2., batch_size=16, num_workers=4)\n",
    "model = SimpleSegmentationModel(task=vad)\n",
    "trainer = pl.Trainer(max_epochs=1, default_root_dir='inference/vad')\n",
    "_ = trainer.fit(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.audio import Inference\n",
    "inference = Inference(model, step=0.1, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference\n",
    "dev_file = next(protocol.development())\n",
    "scores = inference(dev_file)\n",
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference on an excerpt \n",
    "from pyannote.core import Segment\n",
    "scores = inference.crop(dev_file, Segment(10, 15))\n",
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference on an excerpt shorter than sliding window duration\n",
    "scores = inference.crop(dev_file, Segment(10, 11.5))\n",
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference on a whole chunk\n",
    "inference = Inference(model, window=\"whole\")\n",
    "scores = inference.crop(dev_file, Segment(10, 15))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference with progress bar\n",
    "inference = Inference(model, step=0.1, batch_size=32, progress_hook='Processing...')\n",
    "scores = inference(dev_file)\n",
    "inference = Inference(model, step=0.1, batch_size=1, progress_hook=True)\n",
    "scores = inference.crop(dev_file, Segment(10, 15))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and apply a pretrained VAD model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.audio import Model\n",
    "model = Model.from_pretrained('inference/vad/lightning_logs/version_0/checkpoints/epoch=0-step=8.ckpt')\n",
    "inference = Inference(model, step=0.1, batch_size=128)\n",
    "scores = inference(dev_file)\n",
    "scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and apply a speaker embedding model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyannote.audio.tasks import SpeakerEmbedding\n",
    "emb = SpeakerEmbedding(protocol, duration=2., num_workers=4)\n",
    "from pyannote.audio.models.embedding.debug import SimpleEmbeddingModel\n",
    "model = SimpleEmbeddingModel(task=emb)\n",
    "trainer = pl.Trainer(max_epochs=1, default_root_dir='inference/emb')\n",
    "_ = trainer.fit(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference using a sliding window\n",
    "inference = Inference(model, duration=1., step=0.5)\n",
    "embeddings = inference(dev_file)\n",
    "\n",
    "data, window = embeddings.data, embeddings.sliding_window\n",
    "data.shape, window.start, window.duration, window.step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference using a sliding window on an excerpt\n",
    "embeddings = inference.crop(dev_file, Segment(5, 12))\n",
    "\n",
    "data, window = embeddings.data, embeddings.sliding_window\n",
    "data.shape, window.start, window.duration, window.step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference using a sliding window on an excerpt shorter than sliding window\n",
    "embeddings = inference.crop(dev_file, Segment(11.1, 12))\n",
    "\n",
    "data, window = embeddings.data, embeddings.sliding_window\n",
    "data.shape, window.start, window.duration, window.step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference on a whole chunk\n",
    "inference = Inference(model, window=\"whole\")\n",
    "embeddings = inference.crop(dev_file, Segment(5, 12))\n",
    "\n",
    "embeddings.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference on a whole chunk shorter than training duration\n",
    "embeddings = inference.crop(dev_file, Segment(5, 5.2))\n",
    "\n",
    "embeddings.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference on a collection of chunks\n",
    "embeddings = inference.crop(dev_file, [Segment(0, 4), Segment(5, 12)])\n",
    "\n",
    "embeddings.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
